function utils = eval_utils()
% Returns a struct of function handles to the evaluation helpers in this
% file (baseline fitting, comparison statistics, quantiles, and printing).
% See registries_test.m and registries_eval.m for example uses.
utils = struct( ...
  'run_baseline', @run_baseline, ...
  'run_comparisons', @run_comparisons, ...
  'quantile_results', @quantile_results, ...
  'compare_learned_Ks', @compare_learned_Ks, ...
  'compare_learned_Ks_likelihoods', @compare_learned_Ks_likelihoods, ...
  'eval_target_type', @eval_target_type);


function [K_ka, obj_vals_ka, ka_time] = ...
  run_baseline(T_train, ka_opts, utils, K_start, base_type)
% Fits a baseline kernel to T_train and measures how long the fit takes.
%
% base_type selects the baseline:
%   'pg', 'eg' - K_ascent(T_train, ka_opts, utils, K_start, base_type)
%                (projected / exponentiated gradient)
%   'diag'     - diagonal kernel built from utils.get_low_moments(T_train)
%   'mm'       - utils.K_moments_init(T_train, 1)
%
% Returns:
%   K_ka        - fitted kernel struct. For 'diag' it has fields D and M
%                 (both the diagonal moment matrix) and V (identity).
%   obj_vals_ka - objective value(s). For 'diag'/'mm' this is
%                 utils.K_log_likelihood(T_train, K_ka.M).
%   ka_time     - elapsed time in seconds (tic/toc).
% Throws an MException with identifier 'eval_utils:BadType' for any other
% base_type.

% Use a private timer handle so that a caller's own tic is not reset.
t0 = tic;
if any(strcmp(base_type, {'pg', 'eg'}))
  [K_ka, obj_vals_ka] = ...
    K_ascent(T_train, ka_opts, utils, K_start, base_type);
elseif strcmp(base_type, 'diag')
  % Builds the likelihood-maximizing diagonal kernel.
  M = utils.get_low_moments(T_train);
  M = diag(diag(M));
  K_ka.D = M;
  K_ka.M = M;
  K_ka.V = eye(size(M));
  obj_vals_ka = utils.K_log_likelihood(T_train, K_ka.M);
elseif strcmp(base_type, 'mm')
  K_ka = utils.K_moments_init(T_train, 1);
  obj_vals_ka = utils.K_log_likelihood(T_train, K_ka.M);
else
  throw(MException('eval_utils:BadType', ...
    ['Recognized types are projected gradient (pg), ' ...
     'exponentiated gradient (eg), diagonal (diag), ' ...
     'and moment-based initialization (mm).']));
end
ka_time = toc(t0);


function stats = run_comparisons(K, K_start, K_em, ...
  obj_vals_em, em_time, K_ka, obj_vals_ka, ka_time, T_train, T_test)
% Collects all comparison statistics for one EM run versus one baseline run
% into a single row vector, laid out as:
%   [KL values (compare_learned_Ks), log-likelihoods
%    (compare_learned_Ks_likelihoods), em_time, ka_time, EM iteration
%    count, baseline iteration count, likelihood differences, KL
%    differences, relative time difference, relative iteration difference,
%    EM time per iteration, baseline time per iteration].
% Iteration counts are the number of entries in obj_vals_em / obj_vals_ka.
[kls, kl_diffs] = compare_learned_Ks(K, K_start, K_em, K_ka);
[lls, ll_diffs] = compare_learned_Ks_likelihoods( ...
  K, K_start, K_em, K_ka, T_train, T_test);

iters_em = numel(obj_vals_em);
iters_ka = numel(obj_vals_ka);

raw_costs = [em_time, ka_time, iters_em, iters_ka];
cost_ratios = [(em_time - ka_time) / ka_time, ...
  (iters_em - iters_ka) / iters_ka, ...
  em_time / iters_em, ...
  ka_time / iters_ka];

stats = [kls, lls, raw_costs, ll_diffs, kl_diffs, cost_ratios];


function quantiles = quantile_results(stats)
% Summarizes repeated trials by their 25th, 50th and 75th percentiles.
%
% stats is indexed as stats(size_index, stat_index, trial). The result is
% quantiles(size_index, stat_index, q), where q = 1, 2, 3 correspond to
% the 25th, 50th and 75th percentiles over the trials (computed by prctile).
[n_sizes, n_stats, n_trials] = size(stats);
pcts = [25, 50, 75];
quantiles = zeros(n_sizes, n_stats, numel(pcts));
for row = 1:n_sizes
  % One row per statistic, one column per trial.
  per_stat = reshape(stats(row, :, :), n_stats, n_trials);
  quantiles(row, :, :) = prctile(per_stat, pcts, 2);
end


function [kls, kl_diffs] = compare_learned_Ks(K, K_start, K_em, K_ka)
% Computes KL divergences between the true kernel K, the starting kernel
% K_start, the EM result K_em, the baseline result K_ka, and diagonal-only
% versions (built via decompose(diag(diag(X.M)))).
%
% kls      = [em_true, em_true_diag, em_start, ka_true, ka_true_diag,
%             ka_start, start_true, true_diag]
% kl_diffs = [em_true - ka_true, em_true - start_true,
%             em_true - em_true_diag, ka_true - start_true,
%             ka_true - ka_true_diag]
n_samples = 10 * size(K, 1);
kl = @(A, B) dpp_kl_div(A, B, n_samples);
diag_only = @(X) decompose(diag(diag(X.M)));

% NOTE(review): dpp_kl_div takes a sample count, so it is presumably
% Monte-Carlo based; the call order below matches the original so any
% RNG stream is consumed identically.
em_true = kl(K, K_em);
em_true_diag = kl(K, diag_only(K_em));
em_start = kl(K_start, K_em);
ka_true = kl(K, K_ka);
ka_true_diag = kl(K, diag_only(K_ka));
ka_start = kl(K_start, K_ka);
start_true = kl(K, K_start);
true_diag = kl(K, diag_only(K));

kls = [em_true, em_true_diag, em_start, ...
  ka_true, ka_true_diag, ka_start, ...
  start_true, true_diag];

kl_diffs = [em_true - ka_true, ...
  em_true - start_true, ...
  em_true - em_true_diag, ...
  ka_true - start_true, ...
  ka_true - ka_true_diag];
  

function [lls, ll_diffs] = ...
  compare_learned_Ks_likelihoods(K, K_start, K_em, K_ka, T_train, T_test)
% Evaluates opt_utils().K_log_likelihood for the EM, baseline, starting and
% true kernels on T_train and T_test. On T_test it also evaluates
% diagonal-only versions of each kernel (diag(diag(X.M))).
%
% lls is the 12-element row
%   [em, test_em, test_em_diag, ka, test_ka, test_ka_diag,
%    start, test_start, test_start_diag, true, test_true, test_diag].
% ll_diffs holds 11 pairwise differences. Training differences are
% normalized by abs(true train likelihood) and test differences by
% abs(true test likelihood).
ou = opt_utils();
ll = @(T, M) ou.K_log_likelihood(T, M);
% Alternative scoring: ll = @(T, M) ou.median_K_log_likelihood(T, M);
diag_of = @(M) diag(diag(M));

ll_em = ll(T_train, K_em.M);
ll_test_em = ll(T_test, K_em.M);
ll_test_em_diag = ll(T_test, diag_of(K_em.M));
ll_ka = ll(T_train, K_ka.M);
ll_test_ka = ll(T_test, K_ka.M);
ll_test_ka_diag = ll(T_test, diag_of(K_ka.M));
ll_start = ll(T_train, K_start.M);
ll_test_start = ll(T_test, K_start.M);
ll_test_start_diag = ll(T_test, diag_of(K_start.M));
ll_true = ll(T_train, K.M);
ll_test_true = ll(T_test, K.M);
ll_test_diag = ll(T_test, diag_of(K.M));

lls = [ll_em, ll_test_em, ll_test_em_diag, ...
  ll_ka, ll_test_ka, ll_test_ka_diag, ...
  ll_start, ll_test_start, ll_test_start_diag, ...
  ll_true, ll_test_true, ll_test_diag];

train_scale = abs(ll_true);
test_scale = abs(ll_test_true);
ll_diffs = [(ll_em - ll_ka) / train_scale, ...
  (ll_em - ll_start) / train_scale, ...
  (ll_ka - ll_start) / train_scale, ...
  (ll_test_em - ll_test_ka) / test_scale, ...
  (ll_test_em - ll_test_start) / test_scale, ...
  (ll_test_em - ll_test_em_diag) / test_scale, ...
  (ll_test_ka - ll_test_start) / test_scale, ...
  (ll_test_ka - ll_test_ka_diag) / test_scale, ...
  (ll_test_em - ll_test_true) / test_scale, ...
  (ll_test_em - ll_test_diag) / test_scale, ...
  (ll_test_true - ll_test_diag) / test_scale];

  
function eval_target_type(r, T_sizes, indices, sum_str)
% Prints selected summary statistics for each training-set size.
%
% r        - struct with field quantiles, indexed as
%            r.quantiles(T_num, stat, q) with q = 1..3 (presumably the
%            25th/50th/75th percentiles from quantile_results; confirm).
% T_sizes  - training-set sizes; the k-th entry labels r.quantiles(k,:,:).
%            Row and column vectors are both accepted.
% indices  - which of the 44 statistics to print, as a numeric or logical
%            index into the layout produced by run_comparisons.
% sum_str  - text appended to every line. It becomes part of the fprintf
%            format, so fprintf interprets any '%' or '\' in it.
%
% Each printed line looks like
%   <name> (<q1>) <q2> (<q3>) ---- T_size = <n>; <sum_str>
% and each T_size group is followed by a blank line.
labels = {'em_true_Kkl', ... % 1
  'em_true_Kkl_diag', ...
  'em_start_Kkl', ...
  'ka_true_Kkl', ...
  'ka_true_Kkl_diag', ...
  'ka_start_Kkl', ...
  'start_true_Kkl', ...
  'true_Kkl_diag', ...
  'll_em', ... % 9
  'll_test_em', ...
  'll_test_em_diag', ...
  'll_ka', ...
  'll_test_ka', ...
  'll_test_ka_diag', ...
  'll_start', ...
  'll_test_start', ...
  'll_test_start_diag', ...
  'll_true', ...
  'll_test_true', ...
  'll_test_diag', ...
  'em_time', ... % 21
  'ka_time', ...
  'num_iters_em', ...
  'num_iters_ka', ...
  '** (ll_em - ll_ka) / abs(ll_true)', ... % 25
  '(ll_em - ll_start) / abs(ll_true)', ...
  '(ll_ka - ll_start) / abs(ll_true)', ...
  '** (ll_test_em - ll_test_ka) / abs(ll_test_true)', ...
  '! (ll_test_em - ll_test_start) / abs(ll_test_true)', ...
  '! (ll_test_em - ll_test_em_diag) / abs(ll_test_true)', ...
  '! (ll_test_ka - ll_test_start) / abs(ll_test_true)', ...
  '! (ll_test_ka - ll_test_ka_diag) / abs(ll_test_true)', ...
  '(ll_test_em - ll_test_true) / abs(ll_test_true)', ...
  '(ll_test_em - ll_test_diag) / abs(ll_test_true)', ...
  '** (ll_test_true - ll_test_diag) / abs(ll_test_true)', ...
  '** em_true_Kkl - ka_true_Kkl', ... % 36
  '! em_true_Kkl - start_true_Kkl', ...
  '! em_true_Kkl - em_true_Kkl_diag', ...
  '! ka_true_Kkl - start_true_Kkl', ...
  '! ka_true_Kkl - ka_true_Kkl_diag', ...
  '(em_time - ka_time) / ka_time', ... % 41
  '(num_iters_em - num_iters_ka) / num_iters_ka', ...
  'em_time / num_iters_em', ...
  'ka_time / num_iters_ka'}; % 44
selected = labels(indices);

% A column vector in a for-loop header yields a single iteration over the
% whole column, so force a row to iterate element by element.
T_sizes = T_sizes(:)';
for T_num = 1:numel(T_sizes)
  fstr = [' (%.4f) %.4f (%.4f) ---- ' ...
    sprintf('T_size = %d; ', T_sizes(T_num)) sum_str '\n'];
  line_fmts = cellfun(@(name) [name fstr], selected, ...
    'UniformOutput', false);
  % Take the data for the SAME statistics as the labels so that each value
  % lines up with its name. Transpose so that fprintf (which reads data in
  % column-major order) consumes the three quantiles of one statistic
  % before moving on to the next.
  vals = reshape(r.quantiles(T_num, indices, :), [], 3)';
  fprintf([line_fmts{:} '\n'], vals);
end
